{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# 自定义 TVM 数据类型\n",
        "\n",
        "**原作者**: [Gus Smith](https://github.com/gussmith23), [Andrew Liu](https://github.com/hypercubestart)\n",
        "\n",
        "在本教程中，将向您展示如何利用 Bring Your Own Datatypes 框架在 TVM 中使用您自己的自定义数据类型。请注意，Bring Your Own Datatypes 框架目前只处理 **software emulated versions of datatypes**。框架不支持开箱即用的自定义加速器数据类型的编译。\n",
        "\n",
        "## Datatype 库\n",
        "\n",
        "Bring Your Own Datatypes 允许用户在 TVM 的原生数据类型(如 ``float`` 下)旁边注册自己的数据类型实现。在一般情况下，这些数据类型实现通常以库的形式出现。例如：\n",
        "\n",
        "- [libposit](https://github.com/cjdelisle/libposit), a posit library\n",
        "- [Stillwater Universal](https://github.com/stillwater-sc/universal), a library with posits, fixed-point numbers, and other types\n",
        "- [SoftFloat](https://github.com/ucb-bar/berkeley-softfloat-3), Berkeley's software implementation of IEEE 754 floating-point\n",
        "\n",
        "在本节中，我们将使用一个已经实现的示例库，位于 ``3rdparty/byodt/myfloat.cc``。我们称之为“myfloat”的这个数据类型实际上只是 IEE-754  浮点数，但它提供了有用的示例\n",
        "以说明在 BYODT 框架中可以使用任何数据类型。\n",
        "\n",
        "## 设置\n",
        "\n",
        "因为我们不使用任何 3rdparty 库，所以不需要设置。\n",
        "\n",
        "如果你想在自己的数据类型库中尝试这种方法，首先使用 ``CDLL`` 将库的函数引入进程空间：\n",
        "\n",
        "```python\n",
        "ctypes.CDLL('my-datatype-lib.so', ctypes.RTLD_GLOBAL)\n",
        "```\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 简单的 TVM 程序\n",
        "\n",
        "我们将首先在 TVM 中编写一个简单的程序；然后，重写它以使用自定义数据类型。\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import tvm\n",
        "from tvm import relay\n",
        "\n",
        "# 我们的基本程序:  Z = X + Y\n",
        "x = relay.var(\"x\", shape=(3,), dtype=\"float32\")\n",
        "y = relay.var(\"y\", shape=(3,), dtype=\"float32\")\n",
        "z = x + y\n",
        "program = relay.Function([x, y], z)\n",
        "module = tvm.IRModule.from_expr(program)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "现在，我们使用 numpy 创建随机数输入到这个程序中："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "x: [0.51729786 0.9469626  0.7654598 ]\n",
            "y: [0.28239584 0.22104536 0.6862221 ]\n"
          ]
        }
      ],
      "source": [
        "import numpy as np\n",
        "\n",
        "np.random.seed(23)  # 为再现性\n",
        "\n",
        "x_input = np.random.rand(3).astype(\"float32\")\n",
        "y_input = np.random.rand(3).astype(\"float32\")\n",
        "print(f\"x: {x_input}\")\n",
        "print(f\"y: {y_input}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "最后，我们准备好运行程序："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "z: [0.7996937 1.168008  1.4516819]\n"
          ]
        }
      ],
      "source": [
        "z_output = relay.create_executor(mod=module).evaluate()(x_input, y_input)\n",
        "print(\"z: {}\".format(z_output))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 添加自定义数据类型\n",
        "\n",
        "现在，我们将做同样的事情，但是我们将为中间计算使用自定义数据类型。\n",
        "\n",
        "我们使用与上面相同的输入变量 ``x`` 和 ``y``，但在添加 ``x + y`` 之前，我们首先通过 ``relay.cast(...)`` call 将 ``x`` 和 ``y`` 强制转换为自定义数据类型。\n",
        "\n",
        "注意我们如何指定自定义数据类型：我们使用特殊的 ``custom[...]`` 语法。另外，注意数据类型后面的“32”：这是自定义数据类型的 bitwidth。这告诉 TVM ``myfloat`` 的每个实例都是 32 位宽的。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n"
          ]
        }
      ],
      "source": [
        "try:\n",
        "    with tvm.transform.PassContext(config={\"tir.disable_vectorize\": True}):\n",
        "        x_myfloat = relay.cast(x, dtype=\"custom[myfloat]32\")\n",
        "        y_myfloat = relay.cast(y, dtype=\"custom[myfloat]32\")\n",
        "        z_myfloat = x_myfloat + y_myfloat\n",
        "        z = relay.cast(z_myfloat, dtype=\"float32\")\n",
        "except tvm.TVMError as e:\n",
        "    # Print last line of error\n",
        "    print(str(e).split(\"\\n\")[-1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Trying to generate this program throws an error from TVM.\n",
        "TVM does not know how to handle any custom datatype out of the box!\n",
        "We first have to register the custom type with TVM, giving it a name and a type code:\n",
        "\n",
        "试图从 TVM 生成此程序会抛出一个错误。TVM 不知道如何处理任何开箱即用的自定义数据类型！我们首先要向 TVM 注册自定义类型，给它一个名称和一个类型代码："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "tvm.target.datatype.register(\"myfloat\", 150)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note that the type code, 150, is currently chosen manually by the user.\n",
        "See ``TVMTypeCode::kCustomBegin`` in [include/tvm/runtime/c_runtime_api.h](https://github.com/apache/tvm/blob/main/include/tvm/runtime/data_type.h).\n",
        "Now we can generate our program again:\n",
        "\n",
        "注意，type 代码 150 目前是由用户手动选择的。参见 [include/tvm/runtime/c_runtime_api.h](https://github.com/apache/tvm/blob/main/include/tvm/runtime/data_type.h) 中的 ``TVMTypeCode::kCustomBegin``。现在我们可以再次生成我们的程序:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x_myfloat = relay.cast(x, dtype=\"custom[myfloat]32\")\n",
        "y_myfloat = relay.cast(y, dtype=\"custom[myfloat]32\")\n",
        "z_myfloat = x_myfloat + y_myfloat\n",
        "z = relay.cast(z_myfloat, dtype=\"float32\")\n",
        "program = relay.Function([x, y], z)\n",
        "module = tvm.IRModule.from_expr(program)\n",
        "module = relay.transform.InferType()(module)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "现在我们有了一个使用 ``myfloat`` 的 Relay 程序！"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "fn (%x: Tensor[(3), float32], %y: Tensor[(3), float32]) {\n",
            "  %0 = cast(%x, dtype=\"custom[myfloat]32\");\n",
            "  %1 = cast(%y, dtype=\"custom[myfloat]32\");\n",
            "  %2 = add(%0, %1);\n",
            "  cast(%2, dtype=\"float32\")\n",
            "}\n"
          ]
        }
      ],
      "source": [
        "print(program)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "现在我们可以无错误地表示我们的程序，让我们试着运行它！"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "  Check failed: (lower) is false: Cast lowering function for target llvm destination type 150 source type 2 not found\n"
          ]
        }
      ],
      "source": [
        "try:\n",
        "    with tvm.transform.PassContext(config={\"tir.disable_vectorize\": True}):\n",
        "        z_output_myfloat = relay.create_executor(\"graph\", mod=module).evaluate()(x_input, y_input)\n",
        "        print(\"z: {}\".format(y_myfloat))\n",
        "except tvm.TVMError as e:\n",
        "    # Print last line of error\n",
        "    print(str(e).split(\"\\n\")[-1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now, trying to compile this program throws an error.\n",
        "Let's dissect this error.\n",
        "\n",
        "The error is occurring during the process of lowering the custom datatype code to code that TVM can compile and run.\n",
        "TVM is telling us that it cannot find a *lowering function* for the ``Cast`` operation, when casting from source type 2 (``float``, in TVM), to destination type 150 (our custom datatype).\n",
        "When lowering custom datatypes, if TVM encounters an operation over a custom datatype, it looks for a user-registered *lowering function*, which tells it how to lower the operation to an operation over datatypes it understands.\n",
        "We have not told TVM how to lower ``Cast`` operations for our custom datatypes; thus, the source of this error.\n",
        "\n",
        "To fix this error, we simply need to specify a lowering function:\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "tvm.target.datatype.register_op(\n",
        "    tvm.target.datatype.create_lower_func(\n",
        "        {\n",
        "            (32, 32): \"FloatToCustom32\",  # cast from float32 to myfloat32\n",
        "        }\n",
        "    ),\n",
        "    \"Cast\",\n",
        "    \"llvm\",\n",
        "    \"float\",\n",
        "    \"myfloat\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The ``register_op(...)`` call takes a lowering function, and a number of parameters which specify exactly the operation which should be lowered with the provided lowering function.\n",
        "In this case, the arguments we pass specify that this lowering function is for lowering a ``Cast`` from ``float`` to ``myfloat`` for target ``\"llvm\"``.\n",
        "\n",
        "The lowering function passed into this call is very general: it should take an operation of the specified type (in this case, `Cast`) and return another operation which only uses datatypes which TVM understands.\n",
        "\n",
        "In the general case, we expect users to implement operations over their custom datatypes using calls to an external library.\n",
        "In our example, our ``myfloat`` library implements a ``Cast`` from ``float`` to 32-bit ``myfloat`` in the function ``FloatToCustom32``.\n",
        "To provide for the general case, we have made a helper function, ``create_lower_func(...)``,\n",
        "which does just this: given a dictionary, it replaces the given operation with a ``Call`` to the appropriate function name provided based on the op and the bit widths.\n",
        "It additionally removes usages of the custom datatype by storing the custom datatype in an opaque ``uint`` of the appropriate width; in our case, a ``uint32_t``.\n",
        "For more information, see [the source code](https://github.com/apache/tvm/blob/main/python/tvm/target/datatype.py).\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "  Check failed: (lower) is false: Add lowering function for target llvm type 150 not found\n"
          ]
        }
      ],
      "source": [
        "# We can now re-try running the program:\n",
        "try:\n",
        "    with tvm.transform.PassContext(config={\"tir.disable_vectorize\": True}):\n",
        "        z_output_myfloat = relay.create_executor(\"graph\", mod=module).evaluate()(x_input, y_input)\n",
        "        print(\"z: {}\".format(z_output_myfloat))\n",
        "except tvm.TVMError as e:\n",
        "    # Print last line of error\n",
        "    print(str(e).split(\"\\n\")[-1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This new error tells us that the ``Add`` lowering function is not found, which is good news, as it's no longer complaining about the ``Cast``!\n",
        "We know what to do from here: we just need to register the lowering functions for the other operations in our program.\n",
        "\n",
        "Note that for ``Add``, ``create_lower_func`` takes in a dict where the key is an integer.\n",
        "For ``Cast`` operations, we require a 2-tuple to specify the ``src_bit_length`` and the ``dest_bit_length``,\n",
        "while for all other operations, the bit length is the same between the operands so we only require one integer to specify ``bit_length``.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "z: [0.7996937 1.168008  1.4516819]\n",
            "x:\t\t[0.51729786 0.9469626  0.7654598 ]\n",
            "y:\t\t[0.28239584 0.22104536 0.6862221 ]\n",
            "z (float32):\t[0.7996937 1.168008  1.4516819]\n",
            "z (myfloat32):\t[0.7996937 1.168008  1.4516819]\n"
          ]
        }
      ],
      "source": [
        "tvm.target.datatype.register_op(\n",
        "    tvm.target.datatype.create_lower_func({32: \"Custom32Add\"}),\n",
        "    \"Add\",\n",
        "    \"llvm\",\n",
        "    \"myfloat\",\n",
        ")\n",
        "tvm.target.datatype.register_op(\n",
        "    tvm.target.datatype.create_lower_func({(32, 32): \"Custom32ToFloat\"}),\n",
        "    \"Cast\",\n",
        "    \"llvm\",\n",
        "    \"myfloat\",\n",
        "    \"float\",\n",
        ")\n",
        "\n",
        "# Now, we can run our program without errors.\n",
        "with tvm.transform.PassContext(config={\"tir.disable_vectorize\": True}):\n",
        "    z_output_myfloat = relay.create_executor(mod=module).evaluate()(x_input, y_input)\n",
        "print(\"z: {}\".format(z_output_myfloat))\n",
        "\n",
        "print(\"x:\\t\\t{}\".format(x_input))\n",
        "print(\"y:\\t\\t{}\".format(y_input))\n",
        "print(\"z (float32):\\t{}\".format(z_output))\n",
        "print(\"z (myfloat32):\\t{}\".format(z_output_myfloat))\n",
        "\n",
        "# Perhaps as expected, the ``myfloat32`` results and ``float32`` are exactly the same!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Running Models With Custom Datatypes\n",
        "\n",
        "We will first choose the model which we would like to run with myfloat.\n",
        "In this case we use [Mobilenet](https://arxiv.org/abs/1704.04861).\n",
        "We choose Mobilenet due to its small size.\n",
        "In this alpha state of the Bring Your Own Datatypes framework, we have not implemented any software optimizations for running software emulations of custom datatypes; the result is poor performance due to many calls into our datatype emulation library.\n",
        "\n",
        "First let us define two helper functions to get the mobilenet model and a cat image.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading /home/xinet/.mxnet/models/mobilenet0.25-9f83e440.zipeb3c4f5d-55fd-40c7-b2a0-1981acc156d2 from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/mobilenet0.25-9f83e440.zip...\n"
          ]
        }
      ],
      "source": [
        "def get_mobilenet():\n",
        "    dshape = (1, 3, 224, 224)\n",
        "    from mxnet.gluon.model_zoo.vision import get_model\n",
        "\n",
        "    block = get_model(\"mobilenet0.25\", pretrained=True)\n",
        "    shape_dict = {\"data\": dshape}\n",
        "    return relay.frontend.from_mxnet(block, shape_dict)\n",
        "\n",
        "\n",
        "def get_cat_image():\n",
        "    from tvm.contrib.download import download_testdata\n",
        "    from PIL import Image\n",
        "\n",
        "    url = \"https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png\"\n",
        "    dst = \"cat.png\"\n",
        "    real_dst = download_testdata(url, dst, module=\"data\")\n",
        "    img = Image.open(real_dst).resize((224, 224))\n",
        "    # CoreML's standard model image format is BGR\n",
        "    img_bgr = np.array(img)[:, :, ::-1]\n",
        "    img = np.transpose(img_bgr, (2, 0, 1))[np.newaxis, :]\n",
        "    return np.asarray(img, dtype=\"float32\")\n",
        "\n",
        "\n",
        "module, params = get_mobilenet()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "It's easy to execute MobileNet with native TVM:\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "ename": "TypeError",
          "evalue": "create_executor() got an unexpected keyword argument 'params'",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
            "\u001b[1;32m/home/xinet/workspace/lxw/tvm/xinetzone/docs/how_to/extend_tvm/bring_your_own_datatypes.ipynb Cell 27\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell://attached-container%2B7b22636f6e7461696e65724e616d65223a222f6f7074696d69737469635f626f7267222c2273657474696e6773223a7b22686f7374223a227373683a2f2f78696e227d7d/home/xinet/workspace/lxw/tvm/xinetzone/docs/how_to/extend_tvm/bring_your_own_datatypes.ipynb#X35sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a>\u001b[0m ex \u001b[39m=\u001b[39m tvm\u001b[39m.\u001b[39;49mrelay\u001b[39m.\u001b[39;49mcreate_executor(\u001b[39m\"\u001b[39;49m\u001b[39mgraph\u001b[39;49m\u001b[39m\"\u001b[39;49m, mod\u001b[39m=\u001b[39;49mmodule, params\u001b[39m=\u001b[39;49mparams)\n\u001b[1;32m      <a href='vscode-notebook-cell://attached-container%2B7b22636f6e7461696e65724e616d65223a222f6f7074696d69737469635f626f7267222c2273657474696e6773223a7b22686f7374223a227373683a2f2f78696e227d7d/home/xinet/workspace/lxw/tvm/xinetzone/docs/how_to/extend_tvm/bring_your_own_datatypes.ipynb#X35sdnNjb2RlLXJlbW90ZQ%3D%3D?line=1'>2</a>\u001b[0m \u001b[39minput\u001b[39m \u001b[39m=\u001b[39m get_cat_image()\n\u001b[1;32m      <a href='vscode-notebook-cell://attached-container%2B7b22636f6e7461696e65724e616d65223a222f6f7074696d69737469635f626f7267222c2273657474696e6773223a7b22686f7374223a227373683a2f2f78696e227d7d/home/xinet/workspace/lxw/tvm/xinetzone/docs/how_to/extend_tvm/bring_your_own_datatypes.ipynb#X35sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a>\u001b[0m result \u001b[39m=\u001b[39m ex\u001b[39m.\u001b[39mevaluate()(\u001b[39minput\u001b[39m)\u001b[39m.\u001b[39mnumpy()\n",
            "\u001b[0;31mTypeError\u001b[0m: create_executor() got an unexpected keyword argument 'params'"
          ]
        }
      ],
      "source": [
        "ex = tvm.relay.create_executor(\"graph\", mod=module, params=params)\n",
        "input = get_cat_image()\n",
        "result = ex.evaluate()(input).numpy()\n",
        "# print first 10 elements\n",
        "print(result.flatten()[:10])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now, we would like to change the model to use myfloat internally. To do so, we need to convert the network. To do this, we first define a function which will help us convert tensors:\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def convert_ndarray(dst_dtype, array):\n",
        "    \"\"\"Converts an NDArray into the specified datatype\"\"\"\n",
        "    x = relay.var(\"x\", shape=array.shape, dtype=str(array.dtype))\n",
        "    cast = relay.Function([x], x.astype(dst_dtype))\n",
        "    with tvm.transform.PassContext(config={\"tir.disable_vectorize\": True}):\n",
        "        return relay.create_executor(\"graph\").evaluate(cast)(array)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now, to actually convert the entire network, we have written [a pass in Relay](https://github.com/gussmith23/tvm/blob/ea174c01c54a2529e19ca71e125f5884e728da6e/python/tvm/relay/frontend/change_datatype.py#L21) which simply converts all nodes within the model to use the new datatype.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from tvm.relay.frontend.change_datatype import ChangeDatatype\n",
        "\n",
        "src_dtype = \"float32\"\n",
        "dst_dtype = \"custom[myfloat]32\"\n",
        "\n",
        "module = relay.transform.InferType()(module)\n",
        "\n",
        "# Currently, custom datatypes only work if you run simplify_inference beforehand\n",
        "module = tvm.relay.transform.SimplifyInference()(module)\n",
        "\n",
        "# Run type inference before changing datatype\n",
        "module = tvm.relay.transform.InferType()(module)\n",
        "\n",
        "# Change datatype from float to myfloat and re-infer types\n",
        "cdtype = ChangeDatatype(src_dtype, dst_dtype)\n",
        "expr = cdtype.visit(module[\"main\"])\n",
        "module = tvm.relay.transform.InferType()(module)\n",
        "\n",
        "# We also convert the parameters:\n",
        "params = {k: convert_ndarray(dst_dtype, v) for k, v in params.items()}\n",
        "\n",
        "# We also need to convert our input:\n",
        "input = convert_ndarray(dst_dtype, input)\n",
        "\n",
        "# Finally, we can try to run the converted model:\n",
        "try:\n",
        "    # Vectorization is not implemented with custom datatypes.\n",
        "    with tvm.transform.PassContext(config={\"tir.disable_vectorize\": True}):\n",
        "        result_myfloat = tvm.relay.create_executor(\"graph\", mod=module).evaluate(expr)(\n",
        "            input, **params\n",
        "        )\n",
        "except tvm.TVMError as e:\n",
        "    print(str(e).split(\"\\n\")[-1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "When we attempt to run the model, we get a familiar error telling us that more functions need to be registered for myfloat.\n",
        "\n",
        "Because this is a neural network, many more operations are required.\n",
        "Here, we register all the needed functions:\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "tvm.target.datatype.register_op(\n",
        "    tvm.target.datatype.create_lower_func({32: \"FloatToCustom32\"}),\n",
        "    \"FloatImm\",\n",
        "    \"llvm\",\n",
        "    \"myfloat\",\n",
        ")\n",
        "\n",
        "tvm.target.datatype.register_op(\n",
        "    tvm.target.datatype.lower_ite, \"Call\", \"llvm\", \"myfloat\", intrinsic_name=\"tir.if_then_else\"\n",
        ")\n",
        "\n",
        "tvm.target.datatype.register_op(\n",
        "    tvm.target.datatype.lower_call_pure_extern,\n",
        "    \"Call\",\n",
        "    \"llvm\",\n",
        "    \"myfloat\",\n",
        "    intrinsic_name=\"tir.call_pure_extern\",\n",
        ")\n",
        "\n",
        "tvm.target.datatype.register_op(\n",
        "    tvm.target.datatype.create_lower_func({32: \"Custom32Mul\"}),\n",
        "    \"Mul\",\n",
        "    \"llvm\",\n",
        "    \"myfloat\",\n",
        ")\n",
        "tvm.target.datatype.register_op(\n",
        "    tvm.target.datatype.create_lower_func({32: \"Custom32Div\"}),\n",
        "    \"Div\",\n",
        "    \"llvm\",\n",
        "    \"myfloat\",\n",
        ")\n",
        "\n",
        "tvm.target.datatype.register_op(\n",
        "    tvm.target.datatype.create_lower_func({32: \"Custom32Sqrt\"}),\n",
        "    \"Call\",\n",
        "    \"llvm\",\n",
        "    \"myfloat\",\n",
        "    intrinsic_name=\"tir.sqrt\",\n",
        ")\n",
        "\n",
        "tvm.target.datatype.register_op(\n",
        "    tvm.target.datatype.create_lower_func({32: \"Custom32Sub\"}),\n",
        "    \"Sub\",\n",
        "    \"llvm\",\n",
        "    \"myfloat\",\n",
        ")\n",
        "\n",
        "tvm.target.datatype.register_op(\n",
        "    tvm.target.datatype.create_lower_func({32: \"Custom32Exp\"}),\n",
        "    \"Call\",\n",
        "    \"llvm\",\n",
        "    \"myfloat\",\n",
        "    intrinsic_name=\"tir.exp\",\n",
        ")\n",
        "\n",
        "tvm.target.datatype.register_op(\n",
        "    tvm.target.datatype.create_lower_func({32: \"Custom32Max\"}),\n",
        "    \"Max\",\n",
        "    \"llvm\",\n",
        "    \"myfloat\",\n",
        ")\n",
        "\n",
        "tvm.target.datatype.register_min_func(\n",
        "    tvm.target.datatype.create_min_lower_func({32: \"MinCustom32\"}, \"myfloat\"),\n",
        "    \"myfloat\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note we are making use of two new functions: ``register_min_func`` and ``create_min_lower_func``.\n",
        "\n",
        "``register_min_func`` takes in an integer ``num_bits`` for the bit length, and should return an operation\n",
        "representing the minimum finite representable value for the custom data type with the specified bit length.\n",
        "\n",
        "Similar to ``register_op`` and ``create_lower_func``, the ``create_min_lower_func`` handles the general case\n",
        "where the minimum representable custom datatype value is implemented using calls to an external library.\n",
        "\n",
        "Now we can finally run the model:\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Vectorization is not implemented with custom datatypes.\n",
        "with tvm.transform.PassContext(config={\"tir.disable_vectorize\": True}):\n",
        "    result_myfloat = relay.create_executor(mod=module).evaluate(expr)(input, **params)\n",
        "    result_myfloat = convert_ndarray(src_dtype, result_myfloat).numpy()\n",
        "    # print first 10 elements\n",
        "    print(result_myfloat.flatten()[:10])\n",
        "\n",
        "# Again, note that the output using 32-bit myfloat exactly the same as 32-bit floats,\n",
        "# because myfloat is exactly a float!\n",
        "np.testing.assert_array_equal(result, result_myfloat)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3.8.13 ('tvm80')",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    },
    "vscode": {
      "interpreter": {
        "hash": "4d56931767869a22775ecec95b6db9cb1d4d3c8e9a80ba4de55eaee096d239e9"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
